11840. Sum of Squares with Segment Tree

 

Segment trees are extremely useful. In particular, "lazy propagation" allows one, for example, to compute sums over a range in O(lg n) and to update ranges in O(lg n) as well. In this problem you will compute something much harder:

The sum of squares over a range with range updates of 2 types:

1) increment in a range

2) set all numbers the same in a range.

 

Input. The first line of the input contains t (t ≤ 25), the number of test cases. The first line of each test case contains two positive integers, n (n ≤ 100,000) and q (q ≤ 100,000). The next line contains n integers, each at most 1000. Each of the next q lines starts with a number that indicates the type of operation:

2 st nd – return the sum of the squares of the numbers with indices in [st, nd] (i.e., from st to nd inclusive), where $1 \le st \le nd \le n$.

1 st nd x – add x to all numbers with indices in [st, nd], where $1 \le st \le nd \le n$ and $-1000 \le x \le 1000$.

0 st nd x – set all numbers with indices in [st, nd] to x, where $1 \le st \le nd \le n$ and $-1000 \le x \le 1000$.

 

Output. For each test case, output "Case <caseno>:" on the first line, followed by one line with the sum of squares for each operation of type 2. Intermediate overflow will not occur with proper use of 64-bit signed integers.

 

Sample Input

2

4 5

1 2 3 4

2 1 4

0 3 4 1

2 1 4

1 3 4 1

2 1 4

1 1

1

2 1 1

 

Sample Output

Case 1:

30

7

13

Case 2:

1

 

 

РЕШЕНИЕ

структуры данных – дерево отрезков

 

Анализ алгоритма

В задаче следует реализовать две групповые операции на отрезке: сложение и присваивание. В каждой вершине дерева отрезков объявим две переменные add и set для хранения информации об отложенных операциях. Соответственно, при проталкивании (операции push) обрабатываем их отдельно. Ещё следует реализовать поддержку суммы квадратов на отрезке.

Рассмотрим отрезок $[i; j]$ с числами $a_i, \ldots, a_j$. Пусть ко всем числам отрезка добавлено число $v$. Сумма на отрезке увеличится на $(j - i + 1) \cdot v$. Рассмотрим, на сколько увеличится сумма квадратов на отрезке. После увеличения чисел на $v$ квадраты на отрезке станут равными $(a_i + v)^2, (a_{i+1} + v)^2, \ldots, (a_j + v)^2$. Их сумма равна $(a_i^2 + a_{i+1}^2 + \ldots + a_j^2) + 2v\,(a_i + \ldots + a_j) + (j - i + 1)\,v^2$. То есть при добавлении $v$ ко всем числам отрезка к текущей сумме квадратов следует добавить $2v \cdot (\text{сумма чисел на отрезке}) + (j - i + 1)\,v^2$. Поэтому вместе с суммой квадратов на отрезке следует также поддерживать и обычную сумму на отрезке.

 

Реализация алгоритма

 

#include <cstdio>

#include <algorithm>

#define MAX 100010

#define NORMAL 0

#define ADD 1

#define SET 2

using namespace std;

 

// One vertex of the segment tree.
struct node
{
  long long sum;    // sum of the values in the vertex's range
  long long sumSq;  // sum of the squares of those values
  long long type;   // kind of deferred operation: NORMAL, ADD or SET
  long long add;    // operand of the deferred operation (increment for ADD,
                    // assigned value for SET)
};

// Tree storage: vertex 1 is the root, the children of vertex v are 2v and
// 2v+1; 4*MAX vertices is the usual safe bound for MAX leaves.
node SegTree[4*MAX];

// Initial values, 1-based (main fills mas[1..n]).
long long mas[MAX];

 

void build(long long *a, int Vertex, int Left, int Right)

{

  SegTree[Vertex].type = NORMAL;

  SegTree[Vertex].add = 0;

  if (Left == Right)

  {

    SegTree[Vertex].sum = a[Left];

    SegTree[Vertex].sumSq = 1LL * a[Left] * a[Left];

  }

  else

  {

    int Middle = (Left + Right) / 2;

 

    build (a, 2*Vertex, Left, Middle);

    build (a, 2*Vertex+1, Middle+1, Right);

 

    SegTree[Vertex].sum = SegTree[2*Vertex].sum + SegTree[2*Vertex+1].sum;

    SegTree[Vertex].sumSq =

      SegTree[2*Vertex].sumSq + SegTree[2*Vertex+1].sumSq;

  }

}

 

void Push(int Vertex, int LeftPos, int Middle, int RightPos)

{

  if (SegTree[Vertex].type == SET)

  {

    SegTree[2*Vertex].type = SegTree[2*Vertex+1].type = SegTree[Vertex].type;

    SegTree[2*Vertex].add = SegTree[2*Vertex+1].add = SegTree[Vertex].add;

 

    SegTree[2*Vertex].sum = (Middle - LeftPos + 1) * SegTree[Vertex].add;

    SegTree[2*Vertex].sumSq = (Middle - LeftPos + 1) *

                               SegTree[Vertex].add * SegTree[Vertex].add;

 

    SegTree[2*Vertex+1].sum = (RightPos - Middle) * SegTree[Vertex].add;

    SegTree[2*Vertex+1].sumSq = (RightPos - Middle) *

                                 SegTree[Vertex].add * SegTree[Vertex].add;

 

    SegTree[Vertex].add = 0;

    SegTree[Vertex].type = NORMAL;

  }

 

  if (SegTree[Vertex].type == ADD)

  {

    SegTree[2*Vertex].add += SegTree[Vertex].add;

    SegTree[2*Vertex].sumSq += (Middle - LeftPos + 1) *

            SegTree[Vertex].add * SegTree[Vertex].add +

            2LL * SegTree[Vertex].add * SegTree[2*Vertex].sum;

    SegTree[2*Vertex].sum += (Middle - LeftPos + 1) * SegTree[Vertex].add;

    if (SegTree[2*Vertex].type == NORMAL) SegTree[2*Vertex].type = ADD;

    if (SegTree[2*Vertex+1].type == NORMAL) SegTree[2*Vertex+1].type = ADD;

 

    SegTree[2*Vertex+1].add += SegTree[Vertex].add;

    SegTree[2*Vertex+1].sumSq += (RightPos - Middle) * SegTree[Vertex].add  *

                         SegTree[Vertex].add +

                         2LL * SegTree[Vertex].add * SegTree[2*Vertex+1].sum;

    SegTree[2*Vertex+1].sum += (RightPos - Middle) * SegTree[Vertex].add;

    SegTree[Vertex].add = 0;

    SegTree[Vertex].type = NORMAL;

  }

}

 

void SetValue(int Vertex, int LeftPos, int RightPos, int Left,

              int Right, int Value)

{

  if (Left > Right) return;

  if ((LeftPos == Left) && (RightPos == Right))

  {

    SegTree[Vertex].add = Value;

    SegTree[Vertex].type = SET;

    SegTree[Vertex].sum = (long long)(Right - Left + 1) * Value;

    SegTree[Vertex].sumSq = (long long)(Right - Left + 1) * Value * Value;

    return;

  }

 

  int Middle = (LeftPos + RightPos) / 2;

  Push(Vertex,LeftPos,Middle,RightPos);

 

  SetValue(2*Vertex, LeftPos, Middle, Left, min(Middle,Right), Value);

  SetValue(2*Vertex+1, Middle+1, RightPos, max(Left,Middle+1), Right, Value);

 

  SegTree[Vertex].sum = SegTree[2*Vertex].sum + SegTree[2*Vertex+1].sum;

  SegTree[Vertex].sumSq =

    SegTree[2*Vertex].sumSq + SegTree[2*Vertex+1].sumSq;

}

 

void AddValue(int Vertex, int LeftPos, int RightPos,

              int Left, int Right, int Value)

{

  if (Left > Right) return;

  if ((LeftPos == Left) && (RightPos == Right))

  {

    SegTree[Vertex].add += Value;

    if (SegTree[Vertex].type == NORMAL) SegTree[Vertex].type = ADD;

 

    SegTree[Vertex].sumSq += (long long)(Right - Left + 1) * Value * Value +

                             2LL * Value * SegTree[Vertex].sum;

    SegTree[Vertex].sum += (long long)(Right - Left + 1) * Value;

    return;

  }

 

  int Middle = (LeftPos + RightPos) / 2;

  Push(Vertex,LeftPos,Middle,RightPos);

 

  AddValue(2*Vertex, LeftPos, Middle, Left, min(Middle,Right), Value);

  AddValue(2*Vertex+1, Middle+1, RightPos, max(Left,Middle+1), Right, Value);

 

  SegTree[Vertex].sum = SegTree[2*Vertex].sum + SegTree[2*Vertex+1].sum;

  SegTree[Vertex].sumSq =

    SegTree[2*Vertex].sumSq + SegTree[2*Vertex+1].sumSq;

}

 

long long SumSq(int Vertex, int LeftPos, int RightPos, int Left, int Right)

{

  if (Left > Right) return 0;

  if ((LeftPos == Left) && (RightPos == Right)) return SegTree[Vertex].sumSq;

 

  int Middle = (LeftPos + RightPos) / 2;

  Push(Vertex,LeftPos,Middle,RightPos);

 

  return SumSq(2*Vertex, LeftPos, Middle, Left, min(Middle,Right)) +

         SumSq(2*Vertex+1, Middle+1, RightPos, max(Left,Middle+1), Right);

}

 

int i, n, q, cs, tests, type, l, r, x;

 

int main(void)

{

  scanf("%d",&tests);

  for(cs = 1; cs <= tests; cs++)

  {

    scanf("%d %d",&n,&q);

    for(i = 1; i <= n; i++)

      scanf("%lld",&mas[i]);

 

    build(mas,1,1,n);

    printf("Case %d:\n",cs);

 

    while(q--)

    {

      scanf("%d",&type);

      if (type == 0)

      {

        scanf("%d %d %d",&l,&r,&x);

        SetValue(1,1,n,l,r,x);

      } else

      if (type == 1)    

      {

        scanf("%d %d %d",&l,&r,&x);

        AddValue(1,1,n,l,r,x);

      } else

      {

        scanf("%d %d",&l,&r);

        printf("%lld\n",SumSq(1,1,n,l,r));

      }

    }

  }

  return 0;

}

 

Реализация алгоритма – второй вариант

 

#include <cstdio>

#include <algorithm>

#define MAX 100010

#define INF 2100000000

using namespace std;

 

// One vertex of the segment tree. Deferred work is handed to the children as
// "assign set (unless set == INF), then add add".
struct node
{
  long long sum;    // sum of the values in the vertex's range
  long long sumSq;  // sum of the squares of those values
  long long add;    // deferred increment; 0 means none
  long long set;    // deferred assignment; INF means none
};

// Vertex 1 is the root; the children of vertex v are 2v and 2v+1.
node SegTree[4*MAX];

// Initial values, 1-based (main fills mas[1..n]).
long long mas[MAX];

 

void build(long long *a, int Vertex, int Left, int Right)

{

  if (Left == Right)

  {

    SegTree[Vertex].sum = a[Left];

    SegTree[Vertex].sumSq = 1LL * a[Left] * a[Left];

    SegTree[Vertex].add = 0;

    SegTree[Vertex].set = INF;

  }

  else

  {

    int Middle = (Left + Right) / 2;

 

    build (a, 2*Vertex, Left, Middle);

    build (a, 2*Vertex+1, Middle+1, Right);

 

    SegTree[Vertex].sum = SegTree[2*Vertex].sum + SegTree[2*Vertex+1].sum;

    SegTree[Vertex].sumSq =

      SegTree[2*Vertex].sumSq + SegTree[2*Vertex+1].sumSq;

    SegTree[Vertex].add = 0;

    SegTree[Vertex].set = INF;

  }

}

 

void Push(int Vertex, int LeftPos, int Middle, int RightPos)

{

  if (SegTree[Vertex].set != INF)

  {

    SegTree[2*Vertex].set = SegTree[Vertex].set;

    SegTree[2*Vertex].add = 0;

    SegTree[2*Vertex].sum = (Middle - LeftPos + 1) * SegTree[Vertex].set;

    SegTree[2*Vertex].sumSq =

      (Middle - LeftPos + 1) * SegTree[Vertex].set * SegTree[Vertex].set;

 

    SegTree[2*Vertex+1].set = SegTree[Vertex].set;

    SegTree[2*Vertex+1].add = 0;

    SegTree[2*Vertex+1].sum = (RightPos - Middle) * SegTree[Vertex].set;

    SegTree[2*Vertex+1].sumSq = (RightPos - Middle) *

                                SegTree[Vertex].set * SegTree[Vertex].set;

    SegTree[Vertex].set = INF;

  }

 

  if (SegTree[Vertex].add != 0)

  {

    SegTree[2*Vertex].add += SegTree[Vertex].add;

    SegTree[2*Vertex].sumSq += (Middle - LeftPos + 1) * SegTree[Vertex].add

                               * SegTree[Vertex].add +

                         2LL * SegTree[Vertex].add * SegTree[2*Vertex].sum;

    SegTree[2*Vertex].sum += (Middle - LeftPos + 1) * SegTree[Vertex].add;

 

    SegTree[2*Vertex+1].add += SegTree[Vertex].add;

    SegTree[2*Vertex+1].sumSq += (RightPos - Middle) * SegTree[Vertex].add  *

                                 SegTree[Vertex].add +

                       2LL * SegTree[Vertex].add * SegTree[2*Vertex+1].sum;

    SegTree[2*Vertex+1].sum += (RightPos - Middle) * SegTree[Vertex].add;

    SegTree[Vertex].add = 0;

  }

}

 

void SetValue(int Vertex, int LeftPos, int RightPos,

              int Left, int Right, int Value)

{

  if (Left > Right) return;

  if ((LeftPos == Left) && (RightPos == Right))

  {

    SegTree[Vertex].add = 0;

    SegTree[Vertex].set = Value;

    SegTree[Vertex].sum = (long long)(Right - Left + 1) * Value;

    SegTree[Vertex].sumSq = (long long)(Right - Left + 1) * Value * Value;

    return;

  }

 

  int Middle = (LeftPos + RightPos) / 2;

  Push(Vertex,LeftPos,Middle,RightPos);

 

  SetValue(2*Vertex, LeftPos, Middle, Left, min(Middle,Right), Value);

  SetValue(2*Vertex+1, Middle+1, RightPos, max(Left,Middle+1), Right, Value);

 

  SegTree[Vertex].sum = SegTree[2*Vertex].sum + SegTree[2*Vertex+1].sum;

  SegTree[Vertex].sumSq =

    SegTree[2*Vertex].sumSq + SegTree[2*Vertex+1].sumSq;

}

 

void AddValue(int Vertex, int LeftPos, int RightPos,

              int Left, int Right, int Value)

{

  if (Left > Right) return;

  if ((LeftPos == Left) && (RightPos == Right))

  {

    SegTree[Vertex].add += Value;

    SegTree[Vertex].sumSq += (long long)(Right - Left + 1) * Value * Value +

                             2LL * Value * SegTree[Vertex].sum;

    SegTree[Vertex].sum += (long long)(Right - Left + 1) * Value;

    return;

  }

 

  int Middle = (LeftPos + RightPos) / 2;

  Push(Vertex,LeftPos,Middle,RightPos);

 

  AddValue(2*Vertex, LeftPos, Middle, Left, min(Middle,Right), Value);

  AddValue(2*Vertex+1, Middle+1, RightPos, max(Left,Middle+1), Right, Value);

 

  SegTree[Vertex].sum = SegTree[2*Vertex].sum + SegTree[2*Vertex+1].sum;

  SegTree[Vertex].sumSq =

    SegTree[2*Vertex].sumSq + SegTree[2*Vertex+1].sumSq;

}

 

long long SumSq(int Vertex, int LeftPos, int RightPos, int Left, int Right)

{

  if (Left > Right) return 0;

  if ((LeftPos == Left) && (RightPos == Right)) return SegTree[Vertex].sumSq;

 

  int Middle = (LeftPos + RightPos) / 2;

  Push(Vertex,LeftPos,Middle,RightPos);

 

  return SumSq(2*Vertex, LeftPos, Middle, Left, min(Middle,Right)) +

         SumSq(2*Vertex+1, Middle+1, RightPos, max(Left,Middle+1), Right);

}

 

int i, n, q, cs, tests, type, l, r, x;

 

int main(void)

{

  scanf("%d",&tests);

  for(cs = 1; cs <= tests; cs++)

  {

    scanf("%d %d",&n,&q);

    for(i = 1; i <= n; i++)

      scanf("%lld",&mas[i]);

 

    build(mas,1,1,n);

    printf("Case %d:\n",cs);

 

    while(q--)

    {

      scanf("%d",&type);

      if (type == 0)

      {

        scanf("%d %d %d",&l,&r,&x);

        SetValue(1,1,n,l,r,x);

      } else

      if (type == 1)    

      {

        scanf("%d %d %d",&l,&r,&x);

        AddValue(1,1,n,l,r,x);

      } else

      {

        scanf("%d %d",&l,&r);

        printf("%lld\n",SumSq(1,1,n,l,r));

      }

    }

  }

  return 0;

}